import copy
import os
import pickle
import time

import numpy as np
import ray
import torch
import tqdm

import datasets
import exp_configs
import metrics
import models
import optimizers


def _evaluate(model_snapshot, train_set, val_set, exp_dict):
    """Return ``(train_loss, val_acc, val_loss)`` for ``model_snapshot``.

    Each value comes from ``metrics.compute_metric_on_dataset``. The metric
    names are taken from ``exp_dict["loss_func"]`` and
    ``exp_dict["acc_func"]``. The evaluation order (train loss, val acc,
    val loss) matches the original inline code.
    """
    train_loss = metrics.compute_metric_on_dataset(model_snapshot, train_set,
                                                   metric_name=exp_dict["loss_func"])
    val_acc = metrics.compute_metric_on_dataset(model_snapshot, val_set,
                                                metric_name=exp_dict["acc_func"])
    val_loss = metrics.compute_metric_on_dataset(model_snapshot, val_set,
                                                 metric_name=exp_dict["loss_func"])
    return train_loss, val_acc, val_loss


def _result_filename(exp_dict, opt_name):
    """Build the pickle file name under which the results of ``exp_dict`` are stored."""
    gam_name = ''
    # Only the 'nls' / 'aloe' optimizers carry gamma_incr/gamma_decr settings.
    if opt_name in ('nls', 'aloe') and 'gamma_incr' in exp_dict['opt']:
        gam_name = 'gd_{0}_gi_{1}_'.format(exp_dict['opt']['gamma_decr'], exp_dict['opt']['gamma_incr'])
    return (exp_dict['dataset'] + '_' + exp_dict['model'] + '_' + opt_name + '_' + gam_name
            + 'dict_run_{0}_std_0.2.pkl'.format(exp_dict["runs"]))


def run_experiment(exp_dict, savedir_base="./results", datadir="./data"):
    """Train one configuration, periodically record metrics, and pickle them.

    Training runs for up to ``exp_dict['max_epoch']`` epochs. It stops early
    after the first epoch at whose end the optimizer's forward count
    (``opt.state['n_forwards']``) exceeds ``exp_dict['max_iteration']``.
    Whenever that forward count reaches or passes a multiple of
    ``exp_dict['num_iterations_per_test']``, the following are evaluated on
    the model snapshot taken *before* the step that reached it:

    * train loss,
    * validation accuracy,
    * validation loss.

    The values are recorded once for every multiple passed.

    Args:
        exp_dict: experiment configuration. Keys read here are ``runs``,
            ``dataset``, ``batch_size``, ``model``, ``loss_func``,
            ``acc_func``, ``opt`` (a name string or a dict with a ``name``
            key), ``max_epoch``, ``num_iterations_per_test`` and
            ``max_iteration``.
        savedir_base: directory the result pickle is written to. It is
            created if missing.
        datadir: forwarded to ``datasets.get_dataset``.

    The pickled dict has the keys ``iteration``, ``train_loss``,
    ``val_acc``, ``val_loss``, ``time``, ``step_size``, ``grad_norm``,
    ``eps_f`` (only filled for 'nls'/'aloe') and ``params``. The initial
    iteration-0 entry fills only ``iteration``/``train_loss``/``val_acc``/
    ``val_loss``, so the other lists are one entry shorter.

    Returns:
        None.
    """
    # set seed
    # ---------------
    seed = 42 + exp_dict['runs']
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Dataset
    # -----------
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     train_flag=True,
                                     datadir=datadir,
                                     exp_dict=exp_dict)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               drop_last=True,
                                               shuffle=True,
                                               batch_size=exp_dict["batch_size"])

    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   train_flag=False,
                                   datadir=datadir,
                                   exp_dict=exp_dict)

    # Model, loss and optimizer
    # -----------
    model = models.get_model(exp_dict["model"],
                             train_set=train_set)
    loss_function = metrics.get_metric_function(exp_dict["loss_func"])

    n_batches_per_epoch = len(train_set) / float(exp_dict["batch_size"])
    opt = optimizers.get_optimizer(opt=exp_dict["opt"],
                                   params=model.parameters(),
                                   n_batches_per_epoch=n_batches_per_epoch)

    name = exp_dict['opt']['name'] if isinstance(exp_dict['opt'], dict) else exp_dict['opt']
    uses_eps_f = name in ('nls', 'aloe')

    # Train & Val
    # ------------
    s_epoch = 0
    print('Starting experiment at epoch %d/%d' % (s_epoch, exp_dict['max_epoch']))

    prev_iteration = 0
    cur_iteration = 0
    prev_model = copy.deepcopy(model)
    num_iterations_per_test = exp_dict['num_iterations_per_test']

    result_dict = {"iteration": [], "train_loss": [], "val_acc": [], "params": opt.get_param_dict(),
                   "eps_f": [], "val_loss": [], "time": [], "step_size": [], "grad_norm": []}

    train_loss, val_acc, val_loss = _evaluate(prev_model, train_set, val_set, exp_dict)
    result_dict["iteration"].append(0)
    result_dict["train_loss"].append(train_loss)
    result_dict["val_acc"].append(val_acc)
    result_dict["val_loss"].append(val_loss)

    start_time = time.time()
    for epoch in range(s_epoch, exp_dict['max_epoch']):
        # Estimate epsilon_f once per epoch for the optimizers that use it.
        if uses_eps_f:
            eps_f = metrics.estimate_eps_f(model, train_set, metric_name=exp_dict["loss_func"],
                                           n_samples=30, sample_size=128, factor=1/5)
            opt.set_eps_f(eps_f)
            print("\nEstimated epsilon_f: ", eps_f)

        print("Starting epoch %d/%d" % (epoch, exp_dict['max_epoch']))
        epoch_seed = exp_dict['runs'] + epoch
        np.random.seed(epoch_seed)
        torch.manual_seed(epoch_seed)
        torch.cuda.manual_seed_all(epoch_seed)

        model.train()
        print("%d - Training model with %s..." % (epoch, exp_dict["loss_func"]))

        for images, labels in tqdm.tqdm(train_loader):
            opt.zero_grad()

            closure = lambda: loss_function(model, images, labels, backwards=False)
            opt.step(closure)

            prev_iteration = cur_iteration
            cur_iteration = opt.state['n_forwards']

            # True iff some multiple of num_iterations_per_test lies in
            # (prev_iteration, cur_iteration].
            if cur_iteration // num_iterations_per_test > prev_iteration // num_iterations_per_test:
                prev_train_loss, prev_val_acc, prev_val_loss = _evaluate(
                    prev_model, train_set, val_set, exp_dict)

                # Record every not-yet-recorded multiple up to and INCLUDING
                # cur_iteration. The original used an exclusive end here, so a
                # step landing exactly on a multiple recorded nothing. That
                # multiple was then logged later with a stale model.
                prev_recorded_iter = result_dict["iteration"][-1]
                new_iters_to_record = list(range(prev_recorded_iter + num_iterations_per_test,
                                                 cur_iteration + 1, num_iterations_per_test))
                n_new = len(new_iters_to_record)
                result_dict["iteration"] += new_iters_to_record
                result_dict["train_loss"] += n_new * [prev_train_loss]
                result_dict["val_acc"] += n_new * [prev_val_acc]
                if uses_eps_f:
                    result_dict["eps_f"] += n_new * [eps_f]
                result_dict["time"] += n_new * [time.time() - start_time]
                result_dict["val_loss"] += n_new * [prev_val_loss]
                result_dict["step_size"] += n_new * [opt.state["step_size"]]
                result_dict["grad_norm"] += n_new * [opt.state["grad_norm"]]

                print("Current iteration ", cur_iteration)
                print("Previous iteration ", prev_iteration)
                print("Recording iterations ", new_iters_to_record)
                print("Train loss", prev_train_loss)
                print("Val acc", prev_val_acc)

            # Snapshot the post-step iterate. If the next step crosses an
            # evaluation point, metrics are computed on this snapshot.
            prev_model = copy.deepcopy(model)

        if cur_iteration > exp_dict['max_iteration']:
            break

    print('Experiment completed')

    # Ensure the output directory exists and always close the file handle.
    os.makedirs(savedir_base, exist_ok=True)
    out_path = os.path.join(savedir_base, _result_filename(exp_dict, name))
    with open(out_path, "wb") as f:
        pickle.dump(result_dict, f)


if __name__ == '__main__':
    # Experiment groups to run; each name is looked up in exp_configs.EXP_GROUPS.
    exp_group_list = ["mnist_nls_large"]
    exp_list = [exp_dict
                for exp_group_name in exp_group_list
                for exp_dict in exp_configs.EXP_GROUPS[exp_group_name]]

    savedir_base = "./results"
    datadir = "./data"

    # Run experiments
    # ----------------------------
    parallel = False

    if not parallel:
        # Sequential: one experiment after another in this process.
        for exp_dict in exp_list:
            run_experiment(exp_dict=exp_dict, savedir_base=savedir_base, datadir=datadir)
    else:
        ray.init()

        @ray.remote
        def parallel_wrapper(exp_dict):
            """Execute a single experiment as a Ray remote task."""
            run_experiment(exp_dict=exp_dict, savedir_base=savedir_base, datadir=datadir)

        result_ids = [parallel_wrapper.remote(exp_dict) for exp_dict in exp_list]
        # Block until every submitted experiment has finished.
        results = ray.get(result_ids)
        ray.shutdown()
